import numpy as np
import itertools
from scipy import linalg
from enum import Enum
from scipy.stats import entropy
#import numba
#from sklearn.base import TransformerMixin, BaseEstimator





class LegendreDecomposition1:
    """Legendre decomposition of a non-negative tensor with one tied parameter.

    The model is a log-linear distribution over the tensor's index grid:
    ``Q[x]`` is proportional to ``exp(sum of theta[s] for all s <= x)``
    (elementwise ``<=``, see ``_compute_Q``), and its expectation parameters
    are ``eta[x] = sum of Q[s] for all s >= x`` (see ``_compute_eta``).
    Fitting drives the model's ``eta`` towards the ``eta`` of the normalised
    input tensor.

    Coordinates in ``beta`` each own a free parameter; every coordinate in
    ``beta_complement`` shares one parameter ``c``. ``beta1`` indexes the
    reduced parameter vector whose LAST entry stands for ``c``; the code
    assumes ``len(beta1) == len(beta) + 1`` -- TODO confirm with callers.

    Coordinates may be tuples, lists or 1-D integer arrays; they are turned
    into tuples before indexing so each one addresses exactly one element.
    """

    def __init__(self, core_size=2, solver='ng',
                 tol=1e-5, max_iter=10, learning_rate=0.1,
                 random_state=None, shuffle=False, verbose=0,
                 ng_step=0.2, ng_fallback_step=0.01, ng_grad_step=2e-11):
        """Store the hyper-parameters.

        Parameters
        ----------
        core_size, random_state, shuffle
            Stored; not used by the code in this class.
        solver : {'ng', 'gd'}
            Natural gradient or coordinate-wise gradient descent.
        tol : float
            Residual threshold at which fitting stops.
        max_iter : int
            Maximum number of iterations.
        learning_rate : float
            Step size of the 'gd' solver.
        verbose : int
            When truthy, print diagnostics. This also sets the global numpy
            print option ``threshold=200``.
        ng_step : float
            Step applied to the direction ``G_bar^{-1} delta_eta`` ('ng').
        ng_fallback_step : float
            Step applied to the pseudo-inverse direction when ``G_bar`` is
            singular ('ng').
        ng_grad_step : float
            Extra plain-gradient step taken after a successful solve ('ng').
        """
        self.core_size = core_size
        self.solver = solver
        self.tol = tol
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.random_state = random_state
        self.shuffle = shuffle
        self.verbose = verbose
        self.ng_step = ng_step
        self.ng_fallback_step = ng_fallback_step
        self.ng_grad_step = ng_grad_step

        if self.verbose:
            np.set_printoptions(threshold=200)

    def fit_transform(self, P, coordinates, coordinates1, coordinates_complement, ori_theta):
        """Fit the decomposition to ``P`` and return its reconstruction.

        Parameters
        ----------
        P : ndarray
            Non-negative input tensor.
        coordinates : sequence of coordinates
            Free-parameter coordinates (``beta``).
        coordinates1 : sequence of coordinates
            Reduced parameter coordinates (``beta1``); the last entry stands
            for the shared complement parameter.
        coordinates_complement : sequence of coordinates
            Coordinates tied to one shared parameter (``beta_complement``).
        ori_theta : ndarray
            Initial natural parameters, same shape as ``P``. Not modified.

        Returns
        -------
        ndarray
            The model tensor rescaled to the total mass of ``P``.

        Sets ``theta``, ``c``, ``reconstruction_err_`` and ``kl_div``.
        """
        self.theta, self.c = self._legendre_decomposition(
            P, coordinates, coordinates1, coordinates_complement, ori_theta)
        Q = self._compute_Q(self.theta) * P.sum()
        self.reconstruction_err_ = self._calc_rmse(P, Q)

        P_flat = self._normalizer(P.ravel())
        Q_flat = self._normalizer(Q.ravel())
        if self.verbose:
            print("Variance of the vector:", np.var(P_flat))
        # KL(P || Q) of the two normalised tensors, natural log.
        self.kl_div = entropy(P_flat, Q_flat, base=np.e)
        return Q

    def _calc_rmse(self, P, Q):
        """Return the relative Frobenius error ``||P - Q|| / ||P||``.

        Despite the name, the returned value is the relative error; the plain
        RMSE is only printed (when verbose).
        """
        diff = P - Q
        if self.verbose:
            print('rmse=', np.sqrt(np.mean(np.square(diff))))
        return np.sqrt(np.sum(diff ** 2)) / np.sqrt(np.sum(P ** 2))

    def _normalizer(self, P):
        """Return ``P`` scaled to sum to one."""
        return P / np.sum(P)

    def _initialize(self, ori_theta=None):
        """Return a fresh float copy of the initial natural parameters.

        The copy keeps the solvers' in-place updates from mutating the
        caller's array, and keeps integer inputs from truncating the
        updates. With ``ori_theta=None`` the parameters start at zero; this
        requires ``self.shape`` to be set.
        """
        if ori_theta is None:
            return np.zeros(self.shape)
        return np.array(ori_theta, dtype=float)

    def _compute_delta_eta_vec(self, eta, beta, beta1, beta_complement):
        """Return ``eta - eta_hat`` projected onto the reduced parameters.

        The result has length ``len(beta1)``. Entry ``i`` is the mismatch at
        ``beta[i]``; the last entry is the mismatch summed over all tied
        ``beta_complement`` coordinates, which both solvers use as the update
        direction of the shared parameter.
        """
        vec = np.zeros(len(beta1))
        for i, v in enumerate(beta):
            v = tuple(v)
            vec[i] = eta[v] - self.eta_hat[v]
        vec[-1] = np.sum([eta[tuple(idx)] - self.eta_hat[tuple(idx)]
                          for idx in beta_complement])
        return vec

    def _fit_gradient_descent(self, P, beta, beta1, beta_complement, ori_theta=None):
        """Fit theta by coordinate-wise gradient descent.

        Each free coordinate ``v`` in ``beta`` is updated in turn with
        ``theta[v] -= learning_rate * (eta[v] - eta_hat[v])``, and ``eta`` is
        recomputed after every single update. The shared parameter ``c`` of
        ``beta_complement`` is then moved by the summed mismatch over its tied
        coordinates and written back to all of them. This mirrors the
        natural-gradient solver's parametrisation.

        Parameters
        ----------
        P : ndarray
            Normalised target tensor.
        beta, beta1, beta_complement : sequences of coordinates
            See the class docstring.
        ori_theta : ndarray or None
            Initial parameters (copied, never modified); zeros when None.

        Returns
        -------
        theta : ndarray
            Fitted natural parameters.
        c : float
            Final value of the shared complement parameter.
        """
        theta = self._initialize(ori_theta)
        self.eta_hat = self._compute_eta(P)
        self.res = 0.
        self.converged_n_iter = None
        free = [tuple(v) for v in beta]
        tied = [tuple(idx) for idx in beta_complement]
        # Seed c the same way the natural-gradient solver seeds theta_vec[-1].
        c = theta[tuple(beta1[-1])]
        if self.verbose:
            print("\n\n============= theta =============")
            print(theta)
            print("\n\n============= eta_hat =============")
            print(self.eta_hat)

        for n_iter in range(self.max_iter):
            eta = self._compute_eta(self._compute_Q(theta))
            if self.verbose:
                print("\n\n============= iteration: {}, eta =============".format(n_iter))
                print(eta)

            self.res = self._compute_residual(eta, free, beta1, tied)
            if self.verbose:
                print("n_iter: {}, Residual: {}".format(n_iter, self.res))
            if self.res <= self.tol:
                self.converged_n_iter = n_iter
                if self.verbose:
                    print("Convergence of theta at iteration: {}".format(self.converged_n_iter))
                break

            for v in free:
                # \theta_v \gets \theta_v - \epsilon \times (\eta_v - \hat{\eta_v})
                grad = self._compute_eta(self._compute_Q(theta)) - self.eta_hat
                theta[v] -= self.learning_rate * grad[v]

            if tied:
                grad = self._compute_eta(self._compute_Q(theta)) - self.eta_hat
                c -= self.learning_rate * np.sum([grad[idx] for idx in tied])
                for idx in tied:
                    theta[idx] = c

            if self.verbose:
                print("\n\n============= iteration: {}, theta =============".format(n_iter))
                print(theta)

        return theta, c

    def _compute_residual(self, eta, beta, beta1, beta_complement):
        """Return the Euclidean norm of the reduced ``eta`` mismatch."""
        s = self._compute_delta_eta_vec(eta, beta, beta1, beta_complement)
        return np.sqrt(np.sum(s ** 2))

    def _fit_natural_gradient(self, P, beta, beta1, beta_complement, ori_theta):
        """Fit theta by natural-gradient steps on the reduced parameters.

        Each iteration computes the reduced mismatch ``eta_vec`` and the
        reduced Fisher matrix ``G_bar``. It then updates
        ``theta_vec -= ng_step * G_bar^{-1} eta_vec + ng_grad_step * eta_vec``,
        falling back to ``ng_fallback_step * pinv(G_bar) eta_vec`` when
        ``G_bar`` is singular. Iteration stops once the residual drops to
        ``tol``, or once it stops decreasing (the ``Constants.EPSILON``
        guard). ``Constants`` is not defined in this class and must be in
        module scope.

        Returns
        -------
        theta : ndarray
            Fitted natural parameters (a copy; ``ori_theta`` is untouched).
        c : float
            Final value of the shared complement parameter.
        """
        theta = self._initialize(ori_theta)
        self.eta_hat = self._compute_eta(P)
        self.res = 0.
        self.converged_n_iter = None

        beta = [tuple(v) for v in beta]
        tied = [tuple(idx) for idx in beta_complement]
        # Rows/columns of G: free coordinates first, then the tied ones.
        # Built from lists so numpy-array inputs are concatenated, not added.
        jacobian_indices = beta + tied

        theta_vec = np.array([theta[tuple(v)] for v in beta1], dtype=float)
        P_flat = self._normalizer(P.ravel())  # loop-invariant KL reference
        if self.verbose:
            print("\n\n============= theta =============")
            print(theta)
            print("\n\n============= eta_hat =============")
            print(self.eta_hat)

        for n_iter in range(self.max_iter):
            Q = self._compute_Q(theta)
            eta = self._compute_eta(Q)
            self.kl_div = entropy(P_flat, self._normalizer(Q.ravel()), base=np.e)

            if self.verbose:
                print("\n\n============= iteration: {}, eta =============".format(n_iter))
                print(eta)

            prev_res = self.res
            self.res = self._compute_residual(eta, beta, beta1, tied)
            if self.verbose:
                print("n_iter: {}, Residual: {}".format(n_iter, self.res))
            # check convergence
            if (self.res <= self.tol) or (prev_res <= self.res and Constants.EPSILON.value <= prev_res):
                self.converged_n_iter = n_iter
                if self.verbose:
                    print("Convergence of theta at iteration: {}".format(self.converged_n_iter))
                break

            # reduced \delta\eta and Fisher information matrix.
            eta_vec = self._compute_delta_eta_vec(eta, beta, beta1, tied)
            G = self._compute_jacobian(eta, jacobian_indices)
            G_bar = self._compute_jacobian_bar(beta1, G)

            if self.verbose:
                print("\n\n============= iteration: {}, delta_eta =============".format(n_iter))
                print(eta - self.eta_hat)
                print("\n\n============= iteration: {}, eta_vec =============".format(n_iter))
                print(eta_vec)

            # TODO: Algorithm 7, Information Geometric Approaches for
            # Neural Network Algorithms to compute G inverse
            try:
                direction = np.linalg.solve(G_bar, eta_vec)
            except np.linalg.LinAlgError:
                theta_vec -= self.ng_fallback_step * np.dot(np.linalg.pinv(G_bar), eta_vec)
            else:
                theta_vec -= self.ng_step * direction
                theta_vec -= self.ng_grad_step * eta_vec

            if self.verbose:
                try:
                    G_inv_bar = np.linalg.inv(G_bar)
                except np.linalg.LinAlgError:
                    G_inv_bar = np.linalg.pinv(G_bar)
                print("\n\n============= iteration: {}, G_inverse =============".format(n_iter))
                print(G_inv_bar)
                print("\n\n============= iteration: {}, theta_vec =============".format(n_iter))
                print(theta_vec)

            # Scatter the reduced vector back; all tied coordinates share c.
            for n, v in enumerate(beta):
                theta[v] = theta_vec[n]
            for idx in tied:
                theta[idx] = theta_vec[-1]
            if self.verbose:
                print("\n\n============= iteration: {}, theta =============".format(n_iter))
                print(theta)

        return theta, theta_vec[-1]

    def _legendre_decomposition(self, P, coordinates, coordinates1, coordinates_complement, ori_theta):
        """Normalise ``P``, store the coordinate sets and run the chosen solver.

        Returns ``(theta, c)``. Raises ValueError for an unknown solver.
        """
        self.shape = P.shape

        # normalize tensor
        self.P = self._normalizer(P)
        self.beta = coordinates
        self.beta1 = coordinates1
        self.beta_complement = coordinates_complement
        if self.verbose:
            print("\n\n============= beta =============")
            print(self.beta)
            print("\n\n============= beta1 =============")
            print(self.beta1)
            print("\n\n============= beta_complement =============")
            print(self.beta_complement)

        if self.solver == 'ng':
            theta, c = self._fit_natural_gradient(
                self.P, self.beta, self.beta1, self.beta_complement, ori_theta)
        elif self.solver == 'gd':
            theta, c = self._fit_gradient_descent(
                self.P, self.beta, self.beta1, self.beta_complement, ori_theta)
        else:
            raise ValueError("Invalid solver {}.".format(self.solver))

        return theta, c

    def _compute_eta(self, Q):
        """Return ``eta`` with ``eta[x] = sum of Q[s] over all s >= x``.

        ``>=`` is elementwise on the index tuple. This is a reversed
        cumulative sum along every axis: O(Q.size) work instead of one full
        slice sum per entry.
        """
        eta = np.array(Q, dtype=float)
        for axis in range(eta.ndim):
            eta = np.flip(np.cumsum(np.flip(eta, axis=axis), axis=axis), axis=axis)
        return eta

    def _compute_Q(self, theta):
        """Return the normalised model tensor for natural parameters ``theta``.

        ``Q[x]`` is proportional to ``exp(sum of theta[s] over all s <= x)``
        (elementwise ``<=``); those lower-set sums are cumulative sums along
        every axis. The maximum is subtracted before exponentiating. This
        leaves ``Q`` mathematically unchanged but prevents overflow (inf/inf
        -> nan) for large parameters.
        """
        theta_sum = np.array(theta, dtype=float)
        for axis in range(theta_sum.ndim):
            theta_sum = np.cumsum(theta_sum, axis=axis)
        if theta_sum.size:
            theta_sum -= theta_sum.max()
        Q = np.exp(theta_sum)
        Q /= Q.sum()
        return Q

    def _compute_jacobian(self, eta, beta):
        """Return the Fisher information matrix over the coordinates ``beta``.

        ``G[i, j] = eta[max(beta[i], beta[j])] - eta[beta[i]] * eta[beta[j]]``,
        where ``max`` is taken elementwise over the index components.

        Raises ValueError if the coordinates' dimensionality differs from
        ``eta.ndim``.
        """
        beta = np.asarray(beta)
        n_dims = beta.shape[1]

        if eta.ndim != n_dims:
            raise ValueError(f"Eta must be a {n_dims}-dimensional array, but got {eta.ndim}-dimensional.")

        # joined[i, j] is the elementwise max (join) of coordinates i and j.
        joined = np.maximum(beta[:, None, :], beta[None, :, :])
        eta_max = eta[tuple(np.moveaxis(joined, -1, 0))]
        eta_values = eta[tuple(beta.T)]
        return eta_max - np.outer(eta_values, eta_values)

    def _compute_jacobian_bar(self, beta1, G):
        """Collapse ``G`` onto the reduced parameter space of size ``len(beta1)``.

        The first ``k = len(beta1) - 1`` rows/columns are copied unchanged.
        All rows/columns from ``k`` onward (the tied complement coordinates,
        which follow the free ones in ``G``) are summed into the last
        row/column.
        """
        size = len(beta1)
        g_bar = np.zeros((size, size))
        if size == 0:
            return g_bar
        k = size - 1
        g_bar[:k, :k] = G[:k, :k]
        g_bar[k, :k] = G[k:, :k].sum(axis=0)
        g_bar[:k, k] = G[:k, k:].sum(axis=1)
        g_bar[k, k] = G[k:, k:].sum()
        return g_bar
